#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

// #include <vector>


// Grid-stride loop over [0, n): each thread starts at its global thread index
// and advances by the total number of launched threads, so every index is
// visited exactly once for any grid/block size.
// The index is kept in 64 bits so that neither `blockIdx.x * blockDim.x` nor
// the `i += stride` step can wrap a signed 32-bit int when n approaches
// INT_MAX (same approach as PyTorch's CUDA_KERNEL_LOOP). `n` is parenthesized
// so expression arguments are compared as a whole.
#define CUDA_1D_KERNEL_LOOP(i, n)                                              \
for (int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; i < (n);      \
     i += (int64_t)blockDim.x * gridDim.x)
// Threads per block used by every kernel launch in this file.
#define THREAD_PER_BLOCK 1024



namespace {

// Forward kernel: inner products at the (m, n) positions listed in `index`.
//
// For every flat work item thd_idx in [0, num_threads):
//   b = thd_idx / Count,  cnt = thd_idx % Count,  (m, n) = index[cnt]
//   output[b][m][n] = sum_d x[b][m][d] * y[b][n][d]
// where Count = index.size(0); the host passes num_threads = B * Count.
//
// Launch: 1-D grid and 1-D block of any size (work is distributed by
// CUDA_1D_KERNEL_LOOP); no shared memory.
// `index` shares x's floating dtype, so its entries are truncated to integers
// here. They are assumed to be valid row indices into x / y (not checked).
// Only the entries listed in `index` are written; every other element of
// `output` is left untouched.
template <typename scalar_t>
__global__ void sparse_matmul_cuda_forward_kernel(
    const int64_t num_threads,
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> x,     // B x M x D
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> y,     // B x N x D
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> index, // Count x 2
    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> output)      // B x M x N
{
    // Loop-invariant sizes, read once per thread.
    const int64_t count = index.size(0); // rows of `index` (cnt_per_batch)
    const int64_t D = x.size(2);

    CUDA_1D_KERNEL_LOOP(thd_idx, num_threads){
        // thd_idx = b * count + cnt
        const int64_t b = thd_idx / count;
        const int64_t cnt = thd_idx % count;
        const int64_t m = static_cast<int64_t>(index[cnt][0]);
        const int64_t n = static_cast<int64_t>(index[cnt][1]);

        // Accumulate in a register and store once. This avoids D read-modify-
        // write round trips to global memory. If `index` lists the same (m, n)
        // twice, both work items now store the same finished value instead of
        // resetting and interleaving partial sums in output[b][m][n].
        scalar_t acc = 0;
        for (int64_t d = 0; d < D; ++d){
            acc += x[b][m][d] * y[b][n][d];
        }
        output[b][m][n] = acc;
    }
}

// Backward kernel, used for both input gradients.
//
// One flat work item per element of grad_x, with M = grad_x.size(1) and
// D = grad_x.size(2): thd_idx = b * (M * D) + m * D + d. The host passes
// num_threads = B * M * D.
//
// Each item adds onto the existing grad_x[b][m][d], with n ranging over
// [0, y.size(1)):
//   transpose == false:  sum_n [mask[m][n] == 1] * grad_output[b][m][n] * y[b][n][d]
//   transpose == true :  sum_n [mask[n][m] == 1] * grad_output[b][n][m] * y[b][n][d]
// The transposed form lets the host compute grad_y by passing (grad_y, x) in
// place of (grad_x, y).
//
// Mask values are truncated to int; an entry contributes only when the
// truncated value is 1. Every work item owns a distinct grad_x element, so no
// atomics are required.
// Launch: 1-D grid and 1-D block of any size; no shared memory.
template <typename scalar_t>
__global__ void sparse_matmul_cuda_backward_kernel(
    const bool transpose,
    const int64_t num_threads,
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> grad_output,
    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> grad_x,
    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> y,
    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> mask)
{
    const int64_t M = grad_x.size(1);
    const int64_t D = grad_x.size(2);
    const int64_t N = y.size(1);
    const int64_t MD = M * D;

    CUDA_1D_KERNEL_LOOP(thd_idx, num_threads){
        const int64_t b = thd_idx / MD;
        const int64_t m = (thd_idx % MD) / D;
        const int64_t d = thd_idx % D;

        // Start from the current value so grad_x is still accumulated into,
        // but keep the running sum in a register and store it once.
        scalar_t acc = grad_x[b][m][d];
        for (int64_t n = 0; n < N; ++n){
            // `transpose` is uniform across the launch, so this branch does not
            // diverge within a warp.
            if (transpose){
                if (static_cast<int>(mask[n][m]) == 1)
                    acc += grad_output[b][n][m] * y[b][n][d];
            }
            else{
                if (static_cast<int>(mask[m][n]) == 1)
                    acc += grad_output[b][m][n] * y[b][n][d];
            }
        }
        grad_x[b][m][d] = acc;
    }
}

} // namespace

// Launches sparse_matmul_cuda_forward_kernel. For every batch b and every row
// (m, n) of `index`, it writes output[b][m][n] = dot(x[b][m], y[b][n]).
//
// Shapes: x B x M x D, y B x N x D, index Count x 2, output B x M x N.
// All four tensors must have x's floating dtype, because every accessor is
// instantiated with the same scalar_t; `index` therefore stores indices as
// floating-point values.
// Only the listed entries of `output` are written.
// NOTE(review): presumably the caller zero-initializes `output` — confirm.
// NOTE(review): the kernel is launched on the legacy default stream rather
// than PyTorch's current CUDA stream.
void sparse_matmul_cuda_forward_launch(
    torch::Tensor x,        // B x M x D
    torch::Tensor y,        // B x N x D
    torch::Tensor index,    // Count x 2
    torch::Tensor output    // B x M x N
) {
    const int64_t B = x.size(0);
    const int64_t cnt_per_batch = index.size(0); // cnt_per_batch <= M*N

    // One work item per (batch, index row) pair, each computing one inner
    // product. Kept in 64 bits: B * cnt_per_batch can exceed INT_MAX.
    const int64_t total_count = B * cnt_per_batch;
    if (total_count <= 0) {
        return;  // nothing to compute; also avoids a zero-block launch
    }
    // The kernel reads index[cnt][0] and index[cnt][1].
    TORCH_CHECK(index.dim() == 2 && index.size(1) >= 2,
                "sparse_matmul_cuda_forward_launch: index must have shape Count x 2, got ",
                index.sizes());

    // The kernel's grid-stride loop covers any remainder, so capping the grid
    // keeps the launch configuration valid without skipping work.
    const int64_t kMaxBlocks = 65535;
    int64_t block_count = (total_count + THREAD_PER_BLOCK - 1) / THREAD_PER_BLOCK;
    if (block_count > kMaxBlocks) {
        block_count = kMaxBlocks;
    }

    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "sparse matmul cuda forward launch!", ([&] {
        sparse_matmul_cuda_forward_kernel<scalar_t><<<static_cast<unsigned int>(block_count), THREAD_PER_BLOCK>>>(
            total_count,
            x.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
            y.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
            index.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
            output.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>()
        );
    }));

    // Kernel launches do not return errors directly; surface them here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "sparse_matmul_cuda_forward_launch: kernel launch failed: ",
                cudaGetErrorString(err));
}

// Launches sparse_matmul_cuda_backward_kernel twice to accumulate the
// gradients of the masked product:
//   grad_x[b][m][d] += sum_n [mask[m][n] == 1] * grad_output[b][m][n] * y[b][n][d]
//   grad_y[b][n][d] += sum_m [mask[m][n] == 1] * grad_output[b][m][n] * x[b][m][d]
//
// Shapes: grad_output B x M x N, grad_x and x B x M x D, grad_y and y
// B x N x D, mask M x N. All tensors must have x's floating dtype. Mask
// values are truncated to int, and only a truncated value of 1 contributes.
// The kernels add onto the existing contents of grad_x and grad_y.
// NOTE(review): presumably the caller passes zero-initialized gradient
// tensors — confirm.
// NOTE(review): launches go to the legacy default stream, not PyTorch's
// current CUDA stream.
void sparse_matmul_cuda_backward_launch(
    torch::Tensor grad_output,  // B x M x N
    torch::Tensor grad_x,       // B x M x D
    torch::Tensor grad_y,       // B x N x D
    torch::Tensor x,            // B x M x D
    torch::Tensor y,            // B x N x D
    torch::Tensor mask)        // M x N
{
    const int64_t B = x.size(0);
    const int64_t D = x.size(2);
    const int64_t M = x.size(1);
    const int64_t N = y.size(1);

    // Both kernels index mask[m][n] with m < M and n < N.
    TORCH_CHECK(mask.dim() == 2 && mask.size(0) >= M && mask.size(1) >= N,
                "sparse_matmul_cuda_backward_launch: mask must cover M x N = ",
                M, " x ", N, ", got ", mask.sizes());

    // The kernel's grid-stride loop covers any remainder, so capping the grid
    // keeps the launch configuration valid without skipping work.
    const int64_t kMaxBlocks = 65535;

    // grad x: one work item per element of grad_x (64-bit to avoid overflow).
    const int64_t grad_x_count = B * M * D;
    if (grad_x_count > 0)
    {
        int64_t block_count = (grad_x_count + THREAD_PER_BLOCK - 1) / THREAD_PER_BLOCK;
        if (block_count > kMaxBlocks) {
            block_count = kMaxBlocks;
        }
        AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "sparse matmul cuda backward launch (grad x)!", ([&] {
            sparse_matmul_cuda_backward_kernel<scalar_t><<<static_cast<unsigned int>(block_count), THREAD_PER_BLOCK>>>(
                false,
                grad_x_count,
                grad_output.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
                grad_x.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
                y.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
                mask.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>()
            );
        }));
        const cudaError_t err = cudaGetLastError();
        TORCH_CHECK(err == cudaSuccess,
                    "sparse_matmul_cuda_backward_launch: grad x kernel launch failed: ",
                    cudaGetErrorString(err));
    }

    // grad y: the same kernel in transposed mode, with (grad_y, x) taking the
    // place of (grad_x, y).
    const int64_t grad_y_count = B * N * D;
    if (grad_y_count > 0)
    {
        int64_t block_count = (grad_y_count + THREAD_PER_BLOCK - 1) / THREAD_PER_BLOCK;
        if (block_count > kMaxBlocks) {
            block_count = kMaxBlocks;
        }
        AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "sparse matmul cuda backward launch (grad y)!", ([&] {
            sparse_matmul_cuda_backward_kernel<scalar_t><<<static_cast<unsigned int>(block_count), THREAD_PER_BLOCK>>>(
                true,
                grad_y_count,
                grad_output.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
                grad_y.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
                x.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
                mask.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>()
            );
        }));
        const cudaError_t err = cudaGetLastError();
        TORCH_CHECK(err == cudaSuccess,
                    "sparse_matmul_cuda_backward_launch: grad y kernel launch failed: ",
                    cudaGetErrorString(err));
    }
}
